# ==============================================================================
# -- Communication Module  -----------------------------------------------------
# ==============================================================================

import os
import json
import carla
import queue
import paho.mqtt.client as mqtt
from collections import namedtuple
import threading
import numpy as np
import comm.comm_config as comm_config
import comm.comm_utils as Utils
from comm.message_types import Topic

class LockBuffer:
    """
    A single-slot, lock-protected holder for the latest message.

    The held content may be a Beacon or a Data message. Every read counts as
    one "step" of ageing; once more reads than ``max_age_steps`` have happened
    since the last write, the content is discarded and reads return None until
    new content is written.
    """

    def __init__(self, _content=None):
        self.lock = threading.Lock()
        self.content = _content
        self.age = 0  # number of reads since the content was last set

    def get_content(self, max_age_steps):
        """
        Return the held content, or None once it has become stale.

        :param max_age_steps: largest read count (since the last write) for
            which the content is still considered fresh
        """
        with self.lock:
            self.age += 1  # each read ages the content by one step
            stale = self.age > max_age_steps
            if stale:
                self.content = None  # drop stale content permanently
            return None if stale else self.content

    def set_content(self, _content):
        """Replace the held content and reset its age to zero."""
        with self.lock:
            self.age = 0
            self.content = _content


class Channel(object):
    """
    An MQTT client bound to one topic, plus a buffer holding the latest
    message received from each peer.

    Incoming messages are put on ``self.queue`` by the module-level
    ``on_message`` callback and drained by a worker thread (``processMsgs``)
    that files them into ``self.buffer``.

    :param transceiver: the transceiver that this channel belongs to; its
        ``id`` (the id of the *vehicle*) becomes this channel's ``id``
    :param topic: the topic that this channel subscribes and publishes to
    """

    # Sentinel put on the queue by destroy() so the worker thread wakes up
    # immediately instead of waiting out the queue.get() timeout.
    _STOP = object()

    def __init__(self, transceiver, topic):
        self.transceiver = transceiver
        self.id = self.transceiver.id
        self.topic = topic
        self.buffer = dict()  # latest rx content (LockBuffer), keyed by str(peer id)
        self.queue = queue.Queue()
        # Teardown state is initialised before any network I/O so that
        # destroy()/__del__ stay safe if connecting fails part-way through.
        self.destroyed = False
        self.process = None
        self.client = mqtt.Client(
            client_id="",  # TODO: set it to be vehicle id
            # True: the broker removes all information about this client when it
            # disconnects. False: durable client; subscription information and
            # queued messages are retained across disconnects (use False if you
            # need the QoS 2 guarantee of exactly-once delivery).
            clean_session=True,
            userdata=self,
        )
        self.client.on_connect = on_connect
        self.client.on_message = on_message
        self.client.connect(host=comm_config.Broker, port=1883)
        self.client.loop_start()
        self.client.subscribe(self.topic, qos=comm_config.QoS)
        self.process = threading.Thread(target=self.processMsgs)
        self.process.start()

    def __del__(self):
        self.destroy()

    def destroy(self):
        """
        Disconnect from the broker and stop the worker thread.

        Safe to call more than once, and safe on a partially constructed
        channel (e.g. when ``__init__`` raised while connecting); calls after
        the first are no-ops.
        """
        # A missing flag means __init__ never got far enough to start anything.
        if getattr(self, "destroyed", True):
            return
        self.destroyed = True
        client = getattr(self, "client", None)
        if client is not None:
            # Disconnect first so the network loop can flush DISCONNECT, then
            # stop the loop thread.
            client.disconnect()
            client.loop_stop()
        self.queue.put(self._STOP)  # wake the worker without the 3 s timeout
        process = self.process
        # Never join the current thread (e.g. if __del__ runs on the worker).
        if (process is not None and process.is_alive()
                and process is not threading.current_thread()):
            process.join()
        print("Destroyed Comm {}-{}".format(self.id, self.topic))

    def publish(self, message):
        """
        Publish a message on this channel's topic.

        The message is JSON-encoded; objects that JSON cannot encode natively
        are serialised via their ``__dict__``. QoS comes from
        ``comm_config.QoS`` and the message is never retained.

        :param message: the message to be published
        """
        msg = json.dumps(message, default=lambda o: o.__dict__)
        self.client.publish(topic=self.topic,
                            payload=msg,
                            qos=comm_config.QoS,
                            retain=False,
                            )

    def saveMsgToPeerBuffer(self, message):
        """
        Save message to the buffer that stores the latest message from each peer.

        :param message: a decoded message exposing an ``id`` attribute
        """
        peer_id = str(message.id)
        peer_buffer = self.buffer.get(peer_id)
        if peer_buffer is None:
            peer_buffer = self.buffer[peer_id] = LockBuffer()
        peer_buffer.set_content(message)

    def checkDistance(self, message):
        """
        Check whether the message comes from a peer within radio range.

        NOTE: range filtering is currently disabled; every message is accepted.
        The check below is unreachable. If it is re-enabled it has a known bug:
        ``LockBuffer.get_content`` requires a ``max_age_steps`` argument that
        the call below does not pass. It also ages this vehicle's own entry
        on every check.
        """
        return True
        # TODO fix bug here (disabled range check, unreachable)
        if self.id != message.id:
            if str(self.id) in self.buffer.keys():
                mData = self.buffer[str(self.id)].get_content()
                mTrans = carla.Transform(carla.Location(x=mData.x, y=mData.y, z=mData.z))
                peerTrans = carla.Transform(carla.Location(x=message.x, y=message.y, z=message.z))
                if not Utils.reachable_check(mTrans, peerTrans, 2 * comm_config.HalfRadioRange):
                    return False
            else:
                return False
        return True

    def processMsgs(self):
        """
        Worker-thread loop: drain the queue and file in-range messages into
        the per-peer buffer until the channel is destroyed.

        A message that cannot be processed because it lacks an expected
        attribute (e.g. no ``id``) is dropped with a log line. This prevents
        a single malformed peer message from killing the thread.
        """
        while not self.destroyed:
            try:
                message = self.queue.get(timeout=3)
            except queue.Empty:
                continue
            try:
                if message is self._STOP:
                    break
                if self.checkDistance(message):
                    self.saveMsgToPeerBuffer(message)
            except AttributeError as err:
                print("{} dropped malformed message on {}: {}".format(
                    self.id, self.topic, err))
            finally:
                self.queue.task_done()

def on_connect(client, userdata, flags, rc):
    """
    The callback for when the client receives a CONNACK response from the server.

    On a successful connection this (re)subscribes the owning channel to its
    topic. paho reconnects automatically while its network loop is running,
    and with ``clean_session=True`` the broker forgets subscriptions across a
    reconnect, so subscribing here keeps the channel receiving after one.
    On a failed connection nothing is subscribed.

    :param client: the client instance for this callback
    :param userdata: the private user data as set in Client() or user_data_set();
        here the Channel owning the client (exposes ``id`` and ``topic``)
    :param flags: response flags sent by the broker
    :param rc: the connection result (0 means the connection was accepted)
    """
    print(str(userdata.id) + " Connected with result code " + str(rc))
    if rc != 0:
        return
    client.subscribe(userdata.topic, qos=comm_config.QoS)
    print(str(userdata.id) + " Subscribe to " + userdata.topic)


def on_message(client, userdata, message):
    """
    The callback for when a PUBLISH message is received from the server.

    The UTF-8 JSON payload is parsed so that every JSON object becomes a
    namedtuple whose type name is the channel topic. The result is put on
    ``userdata.queue`` without blocking.

    :param client: the client instance for this callback (unused here)
    :param userdata: the private user data as set in Client() or user_data_set();
        expected to expose ``topic`` and ``queue``
    :param message: an instance of MQTTMessage, with members topic, payload,
        qos and retain
    """
    def as_record(fields):
        # Build a fresh namedtuple type from this JSON object's keys.
        record_type = namedtuple(userdata.topic, fields.keys())
        return record_type(*fields.values())

    text = message.payload.decode("utf-8")
    parsed = json.loads(text, object_hook=as_record)
    userdata.queue.put_nowait(parsed)  # hand off to the channel's queue
